########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import json, math, random, os, sys, string
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime
from rwkv.utils import PIPELINE
# Module-level tokenizer pipeline shared by MyDataset.__getitem__ for encoding
# question/answer text with the RWKV world vocabulary.
pipeline = PIPELINE('rwkv6', "rwkv_vocab_v20230424")
import logging
from datasets import load_dataset
from tqdm import tqdm
import copy
from numpy import pad
import pyarrow.parquet as pq
# Log to app.log, truncating it on every run ('w' filemode); used in
# __getitem__ to record samples that overflow the 2048-token context.
logging.basicConfig(filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
class MyDataset(Dataset):
    """Training dataset for the RWKV language model.

    NOTE(review): although __init__ has branches for args.data_type
    "binidx" / "numpy" / "uint16", its final statement unconditionally
    replaces self.data with the MetaMathQA-40K json dataset, so the
    earlier branches currently only populate sizing metadata. Likewise,
    __getitem__ almost always takes the hard-coded Q/A-packing path
    (the ``if True:`` branch) that samples from MetaMathQA.
    """

    def __init__(self, args):
        """Load datasets and compute bookkeeping sizes.

        `args` is the training namespace; fields read here include
        vocab_size, data_type, data_file, my_pile_version, my_qa_mask,
        my_pile_stage, epoch_steps, real_bsz, ctx_len and magic_prime.
        """
        self.args = args
        self.index = 0
        self.vocab_size = args.vocab_size
        self.data_size = 1024
        # Column 0 of a local LAMBADA parquet shard; only referenced by the
        # commented-out padding experiment in __getitem__.
        self.lambada = pq.read_table("lambada/train-00000-of-00001.parquet")[0]
        if args.data_type == "binidx":

            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")

            if args.my_pile_version == 1:
                self.data = MMapIndexedDataset(args.data_file)
                # Token count = raw byte length / bytes per token.
                self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
                rank_zero_info(f"Data has {self.data_size} tokens.")
            elif args.my_pile_version == 2:
                # data_file is a text index: one "path chunk_count cutoff" line
                # per binidx shard; the last field of the last line is the total.
                data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
                data_list = [i.strip().split(' ') for i in data_list]
                self.data = []
                self.data_size = int(data_list[-1][-1])
                rank_zero_info(f"Data has {self.data_size} chunks.")
                for d in data_list:
                    data = MMapIndexedDataset(d[0])
                    data_size = len(data._bin_buffer) // data._index._dtype_size
                    assert (data_size - args.ctx_len) == int(d[1])
                    self.data += [[int(d[-1]), int(d[1]), data]]
                # rank_zero_info(self.data)

            if args.my_qa_mask > 0:
                # self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
                self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document')
                # BUG FIX: the pile's size must use the pile's OWN dtype size —
                # the original divided by self.data._index._dtype_size, which is
                # wrong when the dtypes differ and crashes with AttributeError
                # when my_pile_version == 2 (self.data is then a plain list).
                self.data_pile_size = len(self.data_pile._bin_buffer) // self.data_pile._index._dtype_size
            else:
                self.data_pile = None
                self.data_pile_size = 0

            if args.my_pile_stage > 0:
                # assert self.data_size == 332115325534 and self.vocab_size == 50277
                self.samples_per_epoch = args.epoch_steps * args.real_bsz
                assert self.samples_per_epoch == 40320
                rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
                dataset_slot = self.data_size // args.ctx_len
                if args.my_pile_stage != 4:
                    # magic_prime drives the deterministic full-coverage sampler
                    # in __getitem__; it must be a prime ≡ 2 (mod 3) just below
                    # the number of ctx_len slots in the dataset.
                    assert MaybeIsPrime(args.magic_prime)
                    assert args.magic_prime % 3 == 2
                    assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
        elif args.data_type == "numpy":
            self.data = np.load(args.data_file).astype("int")
            self.vocab_size = args.vocab_size
            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
            self.data_size = len(self.data)
            rank_zero_info(f"Data has {self.data_size} tokens.")
        elif args.data_type == "uint16":
            # Fixed-length samples: reshape the flat token stream into rows of
            # my_sample_len tokens each.
            self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
            self.vocab_size = args.vocab_size
            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
            self.data_size = self.data.shape[0]
            rank_zero_info(f"Data has {self.data_size} samples.")
        else:
            # Dead code below: a triple-quoted string literal (no-op), kept
            # verbatim as project history of the char-level/dummy data path.
            '''
            if args.data_type == "dummy":
                rank_zero_info("Building dummy data...")
                self.data = ""
                for i in range(100000):
                    aa = (i) % 10000
                    bb = (i * i) % 10000
                    cc = aa + bb
                    self.data += f".{aa}+{bb}={cc}."
            else:
                pass##self.data = open(args.data_file, "r").read()
            rank_zero_info("Building token list...")
            unique = sorted(list(set(self.data)))
            self.vocab_size = 1114112
            # rank_zero_info()
            # for u in unique:
            #     print(u, end=' ')
            # rank_zero_info('\n\n')
            xx = 0
            xxObj = {}
            for codepoint in range(1114112):  # UTF-8 编码范围是 0-1114111
                try:
                    char_ = chr(codepoint)
                    #if char in string.printable:  # 只选择可打印字符
                    xxObj[xx%100] = char_
                    xx += 1
                except ValueError:
                    continue
                if (xx% 100 == 1):
                    xxObj = {}
                    with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-8") as vocab_file:
                        vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
            self.data_size = len(self.data)
            rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
            self.stoi = {ch: i for i, ch in enumerate(unique)}
            self.itos = {i: ch for i, ch in enumerate(unique)}'''
        # Unconditionally use MetaMathQA-40K as the actual training corpus,
        # overriding whatever the branches above set self.data to.
        self.data = load_dataset("json", data_files = "metamath/MetaMathQA-40K.json")

    def __len__(self):
        # One "epoch" is epoch_steps optimizer steps of micro_bsz samples each.
        return self.args.epoch_steps * self.args.micro_bsz

    def __getitem__(self, idx):
        """Return one training pair of 2048-token tensors (in_q, in_a).

        Active path: packs as many MetaMathQA question prompts as fit into
        `in_q` and the matching answer/chain-of-thought encodings into
        `in_a`, then pads both with token 33 up to length 2048.
        """
        args = self.args
        # NOTE(review): global_rank / real_epoch / world_size are not set in
        # this class — presumably attached to the dataset object externally by
        # the Lightning trainer; confirm against the training loop.
        rank = self.global_rank
        epoch = self.real_epoch
        world_size = self.world_size
        # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
        devices = int(args.devices)
        if devices>1:
            # Stripe indices across devices so each rank draws distinct samples.
            idx = idx*devices+rank

        if args.data_type == "uint16":
            i = np.random.randint(0, self.data_size-1)
            dix = self.data[i]
            # Standard next-token shift: x is the sample minus its last token,
            # y is the sample minus its first token.
            x = torch.tensor(dix[:-1], dtype=torch.long)
            y = torch.tensor(dix[1:], dtype=torch.long)
        else:
            i = np.random.randint(0, 39999)
            ctx_len = args.ctx_len
            req_len = ctx_len + 1
            magic_prime = args.magic_prime
            data = self.data

            if args.my_pile_stage > 0:
                # Global sample counter across epochs/ranks.
                ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank

                if args.my_qa_mask > 0:
                    ii_orig = ii
                    if ii % 2 == 0:
                        # Every other sample comes from the pile instead.
                        ii = -1
                        data = self.data_pile
                    else:
                        ii = ii // 2
                if data == self.data_pile:
                    i = np.random.randint(0, self.data_pile_size - req_len)
                else:
                    if args.my_pile_stage == 4 or ii < args.my_random_steps:
                        # cheat: pick a random spot in dataset
                        if args.my_pile_version == 1:
                            i = np.random.randint(0, self.data_size - req_len)
                        else:
                            i = np.random.randint(0, self.data_size)
                    else:
                        ii = ii - args.my_random_steps
                        # Deterministic scrambled walk: cubing ii modulo the
                        # magic prime (scaled by the golden ratio) is intended
                        # to visit every ctx_len-aligned slot once per cycle.
                        factor = (math.sqrt(5) - 1) / 2
                        factor = int(magic_prime * factor)
                        i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
                        i = i + args.my_pile_shift
                # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
            else:
                # cheat: pick a random spot in dataset
                ##i = np.random.randint(0, self.data_size - req_len)
                pass
            # Defaults in case the packing loop below produces nothing.
            in_q = torch.ones(2048).unsqueeze(0)
            in_a = torch.ones(2048).unsqueeze(0)
            if False:
                # Disabled: original binidx fetch path (pad / only / offset get).
                if args.my_pile_version == 1:
                    if args.dataload == 'pad':
                        dix, min_len = data.pad(idx=idx, length=req_len)
                    elif args.dataload == 'only':
                        dix = data.only(idx=idx, length=req_len).astype(int)
                    else:
                        dix = data.get(idx=0, offset=i, length=req_len).astype(int)
                else:
                    # self.data : cutoff, chunk_count, data
                    for j in range(len(data)):
                        if i < data[j][0]:
                            ii = i
                            i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
                            dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
                            # print(ii, j, i)
                            break
            elif args.data_type == "numpy":
                dix = data[i : i + req_len]
            if True:
                #args.data_type == "binidx":
                test_set_n = 0
                sampled_data = []


                ds = data##load_dataset("arc-cot/data")
                i = np.random.randint(0, 4076206)

                #for i__ in range(0,4):
                already_len_q = 0
                already_len_a = 0
                x = []
                y = []
                try_ = True
                q_ = ""
                a_ = ""
                qa_num = 0
                # Pack Q/A pairs until either side nears the 2048 budget.
                while already_len_a < 2020 and already_len_q < 1724 :
                    if (i < 39999):
                        i += 1
                    else:
                        # Wrap around: MetaMathQA-40K has 40k rows.
                        i = np.random.randint(0, 39999)
                    a_ = pipeline.encode("The answer and the chain of thoughts to the question :" + ds['train'][i]['query'] + "is" + ds['train'][i]['response'] ) 
                    q_ = pipeline.encode("Question No."+str(qa_num)+"is"+ds['train'][i]['query'] ) 
                    if (len(a_) + already_len_a < 2046 and len(q_) + already_len_q < 2046):
                        qa_num += 1

                        already_len_a += len(a_)
                        already_len_q += len(q_) 
                        x.extend(q_ )

                        y.extend(a_ )

                        '''if (try_ == False):
                            while already_len_a < 2048:
                                i = np.random.randint(0, 1024)
                                
                                a_ = pipeline.encode(ds['train'][i]['answer'] ) 
                                q_ = pipeline.encode(ds['train'][i]['question'] ) 
                                already_len_a += len(a_)
                                already_len_q += len(q_) 
                                x.extend(pipeline.encode(ds['train'][i]['answer'] )  )

                            x = x[:2048]'''
                    else:
                        # Pair didn't fit: retry once with the next row, then stop.
                        if (try_):
                            i += 1
                            try_ = False
                        else:
                            break

                n_original = 2048 - len(y)
                n = 0
                # Dead code (string literal): abandoned LAMBADA-padding experiment.
                '''i = np.random.randint(0, 2570)
                if_new_lambada_try = 5
                lambada_data = pipeline.encode(str(self.lambada[i]) ) 
                len_lambada = len(lambada_data)
                if (len_lambada < n_original - 10):
                    ##防止pipeline.encode导致溢出
                    pass
                else:
                    while if_new_lambada_try > 0 :
                        
                        i = np.random.randint(0, 2550)
                        lambada_data = pipeline.encode(str(self.lambada[i]) ) 
                        len_lambada = len(lambada_data)
                        len_lambada = len(str(self.lambada[i] ))
                        if (len_lambada < n_original - 10):
                            n = n_original - len_lambada
                            break
                        if_new_lambada_try -= 1
                if if_new_lambada_try == 0:
                    n = 2048 - len(y) 
                    for h in range(n):
                        y.append(33)
                    n = 2048 - len(x) 
                    for h in range(n):
                        x.append(33)
                else:
                    n = 2048 - len(y) - len_lambada
                    for h in range(n):
                        y.append(33)
                    n = 2048 - len(x) - len_lambada
                    for h in range(n):
                        x.append(33)
                    x.extend(lambada_data)
                
                    y.extend(lambada_data)'''
                # Pad both sequences to exactly 2048 with token 33.
                n = 2048 - len(y) 
                for h in range(n):
                    y.append(33)
                n = 2048 - len(x) 
                for h in range(n):
                    x.append(33)
                # Truncate and log if encoding overflowed the budget.
                if len(x) > 2048:
                    x = x[:2048]
                    logging.info(f'x大于 2048')
                if len(y) > 2048:
                    y = y[:2048]
                    logging.info(f'y大于 2048')
                in_q = torch.tensor(x)
                in_a = torch.tensor(y)
            #x = pad(x, (0, 2048 - len(x)), 'constant')
            #y = pad(y, (0, 2048 - len(x)), 'constant')
            '''if (len(x) != 2048):
                print("no")
            if (len(y) != 2048):
                print("no_:", y)'''
            return in_q, in_a
            # Everything below `return` is a dead string literal kept verbatim:
            # the original qa-mask / loss-mask sample path.
            '''
            if args.my_qa_mask == 1:
                if data == self.data_pile:
                    z = [1] * ctx_len
                else:
                    z = [0] * ctx_len
                    z_sum = 0
                    isGood = False
                    for i in range(3, ctx_len):
                        if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
                            isGood = True
                        if dix[i] == 0:
                            isGood = False
                        if isGood:
                            z[i] = 1
                            z_sum += 1
                    if z_sum == 0:
                        z = [1] * ctx_len
                        i = np.random.randint(0, self.data_pile_size - req_len)
                        dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
                z = torch.tensor(z, dtype=torch.bfloat16)
            dix = [0]
            x = torch.tensor(dix[:-1], dtype=torch.long)
            y = torch.tensor(dix[1:], dtype=torch.long)

            # if ii_orig < 50:
            #     # if rank == 1:
            #     print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
            # else:
            #     exit(0)

            if args.my_qa_mask == 1:
                return x, y, z
            if args.loss_mask=='qa':

                t1 = pipeline.encode('question:')
                t2 = pipeline.encode('answer:')
                mask = self.create_mask(dix, t1, t2, min_len)
                return x, y, mask
            
            if args.loss_mask=='pad':
                mask = torch.zeros(req_len-1)
                mask[:min_len-1] = 1
                return x, y, mask
                

            return x, y'''

    def create_mask(self, seq, token1, token2, min_len):
        """Build a 0/1 loss mask over `seq` aligned with next-token targets.

        Positions from an occurrence of `token2` (answer marker) onward are
        1; positions from `token1` (question marker) onward are 0. Returns
        mask[1:] (length len(seq)-1) to match the shifted target sequence.
        If no marked span is found at all, the first min_len-1 positions
        are unmasked (set to 1) as a fallback.
        """
        # Find all start indices of the special marker token sequences.
        indices1 = []
        for i in range(min_len - len(token1) + 1):
            if np.array_equal(seq[i:i + len(token1)], token1):
                indices1.append(i)
        indices2 = []

        for i in range(min_len - len(token2) + 1):
            if np.array_equal(seq[i:i + len(token2)], token2):
                indices2.append(i)
        mask = torch.zeros(seq.shape)
        #assert len(indices2)!=0 and len(indices1)!=0
        select = 0
        for i in range(min_len):
            if i in indices1:
                select = 0
            elif i in indices2:
                select = 1
            mask[i] = select
        if torch.sum(mask)==0:
            # No answer region detected: train on the whole (unpadded) prefix.
            mask[:min_len-1] = 1
        return mask[1:]

